import os
import torch
import numpy as np


def _resume(args, model, optimizer, replaybuffer):
    """Optionally resume a run from a checkpoint, then write a fresh checkpoint.

    When ``args.resume`` is truthy, ``restart_from_chkpt`` is called on
    ``<args.output_path>/chkpt_final.pt`` with the model, optimizer and replay
    buffer, and the ``epoch`` / ``best_acc`` run variables are refreshed from it.
    Otherwise the defaults (epoch ``0``, best accuracy ``0.``) are used.

    In both cases the current state is then saved through ``_save_model``.

    Args:
        args: run configuration; ``resume`` and ``output_path`` are read here.
        model: object handed to ``restart_from_chkpt`` as ``state_dict``.
        optimizer: object handed to ``restart_from_chkpt`` as ``optimizer``.
        replaybuffer: object handed to ``restart_from_chkpt`` as ``replay_buffer``.

    Returns:
        Tuple ``(start_epoch, best_acc)``.

    Raises:
        ValueError: propagated from ``restart_from_chkpt`` when resuming and the
            checkpoint file does not exist.
    """
    # NOTE(review): this reads 'chkpt_final.pt' while load_model in this file
    # reads 'chkpt_final.pth' -- confirm which extension is intended.
    to_restore = {"epoch": 0, "best_acc": 0.}
    if args.resume:
        restart_from_chkpt(
            os.path.join(args.output_path, "chkpt_final.pt"),
            run_variables=to_restore,
            state_dict=model,
            optimizer=optimizer,
            replay_buffer=replaybuffer,
        )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]
    # Record the restored epoch in the new checkpoint; the default epoch=0
    # would otherwise discard the progress that was just resumed.
    _save_model(args, best_acc, model, optimizer, replaybuffer, epoch=start_epoch)
    return start_epoch, best_acc


def _reload(args, loss, model, optimizer, replaybuffer):
    """Restore model, optimizer and replay buffer from ``args.chkpt_path``.

    Prints the absolute loss value, runs a best-effort backward pass, clears the
    model gradients, reloads from the checkpoint, re-seeds via
    ``tool.util.set_seed`` with a random seed below 2023, and sets
    ``args.reloaded = True``.

    Args:
        args: run configuration; ``chkpt_path`` is read, ``reloaded`` is set.
        loss: tensor whose ``abs().item()`` is printed and on which
            ``backward()`` is attempted.
        model: object restored under the ``state_dict`` keyword.
        optimizer: object restored under the ``optimizer`` keyword.
        replaybuffer: object restored under the ``replay_buffer`` keyword.

    Raises:
        ValueError: propagated from ``restart_from_chkpt`` when the checkpoint
            file does not exist.
    """
    print('reloading model (loss:{})'.format(loss.abs().item()))
    try:
        # Best effort: presumably run so autograd releases the graph held by
        # ``loss`` -- TODO confirm. A failure here must not abort the reload.
        loss.backward()
    except Exception:
        # Narrower than a bare ``except`` so that KeyboardInterrupt and
        # SystemExit still propagate.
        pass
    model.zero_grad()
    restart_from_chkpt(
        args.chkpt_path,
        run_variables=None,
        state_dict=model,
        optimizer=optimizer,
        replay_buffer=replaybuffer,
    )
    model.zero_grad()
    from tool.util import set_seed
    set_seed(np.random.randint(2023))
    args.reloaded = True

def load_model(args, output_path, model=None, is_final=True):
    """Load checkpoint weights into a model, move it to CUDA and set eval mode.

    Args:
        args: passed to ``_init_model`` when no ``model`` is supplied.
        output_path: directory that contains the checkpoint file.
        model: existing model to fill; a new one is built when ``None``.
        is_final: indexes the checkpoint name, so a false value selects
            ``chkpt_init.pth`` and a true value selects ``chkpt_final.pth``.

    Returns:
        The model after ``cuda()`` and ``eval()`` have been called on it.
    """
    from model.model_getter import _init_model

    candidates = ['chkpt_init.pth', 'chkpt_final.pth']
    path_model = os.path.join(output_path, candidates[is_final])
    if model is None:
        model = _init_model(args)
    load_pretrained_weights(model, path_model)
    model.cuda()
    model.eval()
    return model


def _save_model(args, best_acc, model, optimizer, replay_buffer, epoch=0, clsf=False):
    """Assemble a checkpoint dict and write it to ``args.chkpt_path``.

    The dict always holds ``epoch``, ``best_acc``, ``model_dict`` and
    ``replay_buffer``. When ``args.mcog_hier`` is truthy, ``optimizer`` is
    indexed at 0 and 1 and stored as ``optimizer0`` / ``optimizer1``; otherwise
    its state is stored under ``optimizer``.

    Args:
        args: run configuration; ``mcog_hier`` and ``chkpt_path`` are read.
        best_acc: value stored under ``best_acc``.
        model: object whose ``state_dict()`` is stored under ``model_dict``.
        optimizer: a single optimizer, or an indexable pair when
            ``args.mcog_hier`` is truthy.
        replay_buffer: stored as-is under ``replay_buffer``.
        epoch: value stored under ``epoch``.
        clsf: not used by this function.
    """
    hierarchical = args.mcog_hier
    payload = {
        "epoch": epoch,
        "best_acc": best_acc,
        "model_dict": model.state_dict(),
    }
    if hierarchical:
        for idx in (0, 1):
            payload["optimizer{}".format(idx)] = optimizer[idx].state_dict()
    else:
        payload["optimizer"] = optimizer.state_dict()
    payload["replay_buffer"] = replay_buffer
    save_on_master(payload, args.chkpt_path)


def save_on_master(*args, **kwargs):
    """Forward every positional and keyword argument to ``torch.save``."""
    writer = torch.save
    writer(*args, **kwargs)

def _strip_module_components(key):
    """Remove every ``module`` path component from a dotted parameter name.

    Only whole components are removed, so ``module.a.module.b`` becomes ``a.b``
    while ``submodule.weight`` is left untouched. The last component is always
    kept, even when it is literally ``module``.
    """
    *parents, leaf = key.split(".")
    return ".".join([part for part in parents if part != "module"] + [leaf])


def load_pretrained_weights(model, model_path):
    """Load weights from a checkpoint file into ``model`` (non-strict).

    The path is printed first. If the file exists it is read with
    ``torch.load`` on the CPU; when the loaded object contains a
    ``model_dict`` entry, that entry is used as the state dict. ``module``
    path components are stripped from the parameter names before
    ``model.load_state_dict(..., strict=False)`` is called, and the returned
    message is printed.

    If the file does not exist, a message is printed and the model is left
    unchanged.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict=...)``.
        model_path: path of the checkpoint file.

    Returns:
        None.
    """
    chkpt_key = 'model_dict'
    print(model_path)
    if not os.path.isfile(model_path):
        # Previously a silent no-op; say so explicitly so that a wrong path
        # does not go unnoticed.
        print('No pretrained weights found at {}; model left unchanged'.format(model_path))
        return

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location="cpu")
    if chkpt_key in state_dict:
        print(f"Take key {chkpt_key} in provided chkpt dict")
        state_dict = state_dict[chkpt_key]

    # A plain ``str.replace("module.", "")`` would also corrupt names such as
    # "submodule.weight"; strip whole path components only.
    state_dict = {_strip_module_components(k): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(model_path, msg))


def _load_state_via_cpu(key, target, state, **load_kwargs):
    """Load ``state`` into ``target`` while on the CPU, then move it to CUDA.

    For ``key == 'optimizer'`` the move is done with ``optimizer_to``; any
    other target is moved with its own ``cpu()`` / ``cuda()`` methods.
    Returns whatever ``target.load_state_dict`` returns.
    """
    if key == 'optimizer':
        optimizer_to(target)
        msg = target.load_state_dict(state, **load_kwargs)
        optimizer_to(target, device='cuda')
    else:
        target = target.cpu()
        msg = target.load_state_dict(state, **load_kwargs)
        target.cuda()
    return msg


def restart_from_chkpt(ckp_path, run_variables=None, **kwargs):
    """Restore objects and run variables from a checkpoint file.

    Each keyword argument names a checkpoint entry and supplies the object to
    restore into:

    * ``replay_buffer`` is overwritten in place with slice assignment.
    * every other object is loaded through ``load_state_dict`` -- first with
      ``strict=False`` and, if that raises ``TypeError``, again without it --
      and is then moved to CUDA.
    * ``state_dict`` falls back to the ``model_dict`` entry when the
      checkpoint has no ``state_dict`` entry, which is the name ``_save_model``
      in this file stores the model weights under.

    Keywords whose value is ``None`` or whose entry is missing are reported
    and skipped.

    Args:
        ckp_path: path of the checkpoint file.
        run_variables: optional dict; every key that also exists in the
            checkpoint is overwritten with the checkpoint value.
        **kwargs: ``entry_name=object`` pairs to restore.

    Raises:
        ValueError: if ``ckp_path`` is not an existing file.
    """
    if not os.path.isfile(ckp_path):
        raise ValueError("No chkpt found at '{}'".format(ckp_path))
    print("Found chkpt at {}".format(ckp_path))

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    chkpt = torch.load(ckp_path, map_location="cpu")
    # NOTE(review): _save_model writes 'optimizer0'/'optimizer1' when
    # args.mcog_hier is set; those entries are only restored if the caller
    # passes keywords with exactly those names.
    for key, value in kwargs.items():
        chkpt_key = key
        if key == 'state_dict' and key not in chkpt and 'model_dict' in chkpt:
            chkpt_key = 'model_dict'
        if chkpt_key not in chkpt or value is None:
            print("=> key '{}' not found in chkpt: '{}'".format(key, ckp_path))
            continue

        state = chkpt[chkpt_key]
        if key == 'replay_buffer':
            # Mutate in place so the caller's buffer object is the one updated.
            value[:] = state
            continue

        try:
            msg = _load_state_via_cpu(key, value, state, strict=False)
            print("=> loaded '{}' from chkpt '{}' with msg {}".format(key, ckp_path, msg))
        except TypeError:
            # The target's load_state_dict does not accept ``strict``.
            try:
                _load_state_via_cpu(key, value, state)
                print("=> loaded '{}' from chkpt: '{}'".format(key, ckp_path))
            except ValueError:
                print("=> failed to load '{}' from chkpt: '{}'".format(key, ckp_path))

    # Reload the variables that are important for the run.
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in chkpt:
                run_variables[var_name] = chkpt[var_name]


def optimizer_to(optim, device='cpu'):
    """Move the tensors held in ``optim.state`` to ``device``, in place.

    Each value of ``optim.state`` is handled if it is a tensor, or a dict whose
    tensor values are then handled; anything else is left alone. For every
    tensor, both its data and (when present) its gradient data are moved.

    Args:
        optim: object exposing a ``state`` mapping.
        device: target device passed to ``Tensor.to``.
    """
    def _relocate(tensor):
        # Rebind .data so the tensor object itself keeps its identity.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # Not sure there are any global tensors in the state dict
            _relocate(entry)
        elif isinstance(entry, dict):
            for item in entry.values():
                if isinstance(item, torch.Tensor):
                    _relocate(item)

